Attention Scoring Functions

Suppose we have a query \(\mathbf{q}\in\mathbb{R}^{q}\) and \(m\) key-value pairs \((\mathbf{k}_{1}, \mathbf{v}_{1}),...,(\mathbf{k}_{m}, \mathbf{v}_{m})\), where \(\mathbf{k}_{i}\in\mathbb{R}^{k}\) and \(\mathbf{v}_{i}\in\mathbb{R}^{v}\). The attention pooling function is then expressed as a weighted sum of the values:

\[f(\mathbf{q}, (\mathbf{k}_{1}, \mathbf{v}_{1}),...,(\mathbf{k}_{m}, \mathbf{v}_{m})) = \sum_{i=1}^{m}\frac{\exp(a(\mathbf{q}, \mathbf{k}_{i}))}{\sum_{j=1}^{m}\exp(a(\mathbf{q}, \mathbf{k}_{j}))}\mathbf{v}_{i}\in\mathbb{R}^{v}\]

The attention scoring function \(a\) maps the query \(\mathbf{q}\) and the key \(\mathbf{k}_{i}\) to a scalar \(a(\mathbf{q}, \mathbf{k}_{i})\), which determines how much attention is paid to the value \(\mathbf{v}_{i}\) associated with that key.

There are many choices of attention scoring functions. This section introduces two popular ones: additive attention and scaled dot-product attention.
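
As a concrete illustration of the weighted sum above, the following minimal sketch computes the pooling output for a single query. A plain dot product is used here purely as a stand-in for a generic scoring function \(a\), and all shapes are chosen arbitrarily.

import torch

# Minimal sketch of the attention pooling formula for a single query.
# The dot product below is only a placeholder for a generic a(q, k_i).
q = torch.randn(4)                        # query, in R^4
keys = torch.randn(6, 4)                  # m = 6 keys, each in R^4
values = torch.randn(6, 3)                # m = 6 values, each in R^3

scores = keys @ q                         # a(q, k_i) for i = 1, ..., m
weights = torch.softmax(scores, dim=0)    # attention weights, sum to 1
output = weights @ values                 # weighted sum of values, in R^3
print(weights.sum(), output.shape)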


Masked Softmax Operation

As the formula above shows, a softmax operation is used to output a probability distribution that serves as the attention weights.

In many cases, however, not all values should be fed into attention pooling; for example, padding tokens in a text sequence should not be attended to.

The following function implements such a masked softmax: any element beyond the valid length is masked out on the last axis before the softmax is applied.

import math
import torch
from torch import nn
from d2l import torch as d2l


#@save
def masked_softmax(X, valid_lens):
    """实现带遮蔽的softmax"""
    # shape of X: (`batch_size`, no. of queries, no. of key-value pairs)
    # shape of valid_lens: None or (`batch_size`,) or (`batch_size`, no. of queries)
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        # Flatten valid_lens into shape (`batch_size` * no. of queries,)
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large
        # negative value, whose exponential is approximately 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, 
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
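
As a quick sanity check, we can feed random scores through masked_softmax with valid lengths of 2 and 3 (the input shapes below are arbitrary): in the first example only the first two attention weights of each row should be nonzero, and in the second example only the first three.

masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))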

Additive Attention

In general, when queries and keys are vectors of different lengths, additive attention can be used as the scoring function.

Given a query \(\mathbf{q} \in \mathbb{R}^{q}\) and a key \(\mathbf{k} \in \mathbb{R}^{k}\), the additive attention scoring function is

\[a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_{h}^{T}\mbox{tanh}(\mathbf{W}_{q}\mathbf{q} + \mathbf{W}_{k}\mathbf{k})\]

where the learnable parameters are \(\mathbf{W}_{q} \in \mathbb{R}^{h\times{q}}\), \(\mathbf{W}_{k} \in \mathbb{R}^{h\times{k}}\), and \(\mathbf{w}_{h} \in \mathbb{R}^{h}\). The following class implements additive attention, applying dropout to the attention weights for regularization.

#@save
class AdditiveAttention(nn.Module):
    """加性注意力"""
    def __init__(self, key_size, query_size, num_hiddens, dropout):
        super(AdditiveAttention, self).__init__()
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        # shape of queries: (`batch_size`, no. of queries, `query_size`)
        # shape of keys: (`batch_size`, no. of key-value pairs, `key_size`)
        # shape of values: (`batch_size`, no. of key-value pairs, `value_size`)
        # shape of valid_lens: either (`batch_size`,) or (`batch_size`, no. of queries)
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, 
        # shape of queries: (`batch_size`, no. of queries, 1, `num_hiddens`)
        # shape of keys: (`batch_size`, 1, no. of key-value pairs, `num_hiddens`). 
        # Sum them up with broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # Shape of `scores`: (`batch_size`, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Output shape: (`batch_size`, no. of queries, `value_size`)
        return torch.bmm(self.dropout(self.attention_weights), values)
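
Below is a small usage sketch with arbitrarily chosen shapes and hyperparameters: a minibatch of 2 examples, each with 1 query of size 20 and 10 key-value pairs (keys of size 2, values of size 4). With valid lengths of 2 and 6, only the leading key-value pairs of each example receive nonzero attention weights.

queries = torch.normal(0, 1, (2, 1, 20))
keys = torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
attention.eval()  # disable dropout so the output is deterministic
attention(queries, keys, values, valid_lens)  # output shape: (2, 1, 4)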

Scaled Dot-Product Attention

Using a dot product yields a computationally more efficient scoring function, but it requires the query and the key to have the same length \(d\).

Suppose that all elements of the query and the key are independent random variables with zero mean and unit variance. Then the dot product of the two vectors has zero mean and variance \(d\). To ensure that the variance of the attention score is 1 regardless of the vector length, the dot product is divided by \(\sqrt{d}\), yielding the scaled dot-product attention scoring function:

\[a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^{T}\mathbf{k}}{\sqrt{d}}\]
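
This scaling argument can be checked empirically. The sketch below (the sample size and \(d\) are chosen arbitrarily) draws random queries and keys with i.i.d. standard-normal entries and compares the empirical variances of the raw and scaled dot products.

d = 64
q = torch.randn(10000, d)           # 10000 queries, entries ~ N(0, 1)
k = torch.randn(10000, d)           # 10000 keys, entries ~ N(0, 1)
raw = (q * k).sum(dim=1)            # raw dot products q^T k
scaled = raw / math.sqrt(d)         # scaled dot products
print(raw.var(), scaled.var())      # roughly d and roughly 1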

From a minibatch perspective, suppose we have \(n\) queries \(\mathbf{Q}\in\mathbb{R}^{n\times{d}}\) and \(m\) key-value pairs with keys \(\mathbf{K}\in\mathbb{R}^{m\times{d}}\) and values \(\mathbf{V}\in\mathbb{R}^{m\times{v}}\). The scaled dot-product attention over the whole minibatch is

\[\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^{T}}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n\times{v}}\]

#@save
class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout):
        super(DotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # Shape of queries: (`batch_size`, no. of queries, `d`)
        # Shape of keys: (`batch_size`, no. of key-value pairs, `d`)
        # Shape of values: (`batch_size`, no. of key-value pairs, `value_size`)
        # Shape of valid_lens: (`batch_size`,) or (`batch_size`, no. of queries)
        d = queries.shape[-1]
        # Shape of `scores`: (`batch_size`, no. of queries, no. of key-value pairs)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Output shape: (`batch_size`, no. of queries, `value_size`)
        return torch.bmm(self.dropout(self.attention_weights), values)
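
A usage sketch analogous to the additive attention example above follows (shapes again chosen arbitrarily); the only new constraint is that queries and keys must share the same last dimension, here \(d = 2\).

queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = DotProductAttention(dropout=0.5)
attention.eval()  # disable dropout so the output is deterministic
attention(queries, keys, values, valid_lens)  # output shape: (2, 1, 4)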